{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y5OeTiryEcoX"
   },
   "source": [
    "# Fine-tuning the Gemma2 2B model on Roadrunner with JAX and Flax\n",
    "\n",
    "We have adapted the Gemma notebook from Google DeepMind to use HuggingFace's libraries, added support for **model-parallel training**, and simplified the setup."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5m81VQOqEcoX"
   },
   "source": [
    "## Setup "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import importlib\n",
    "def import_local_module(module_path: str):\n",
    "    \"\"\"Import a module by dotted path, resolvable from the working directory, and reload it.\n",
    "\n",
    "    Args:\n",
    "        module_path: Dotted module name, e.g. \"trainer_engine.setup\".\n",
    "\n",
    "    Returns:\n",
    "        The module object returned by importlib.reload(), so edits made on disk\n",
    "        since a previous import are picked up without restarting the kernel.\n",
    "    \"\"\"\n",
    "    # '' makes the current working directory importable; add it only once so that\n",
    "    # re-running cells does not keep growing sys.path with duplicate entries.\n",
    "    if '' not in sys.path:\n",
    "        sys.path.append('')\n",
    "    module = importlib.import_module(module_path)\n",
    "    return importlib.reload(module)\n",
    "\n",
    "# Load (or reload) the felafax trainer_engine setup helpers from the working directory,\n",
    "# then run its environment setup. The resulting `setup` module is reused by later cells.\n",
    "setup = import_local_module(\"trainer_engine.setup\")\n",
    "setup.setup_environment()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Use the %pip magic (not !pip) so packages are installed into the environment of the\n",
    "# running kernel rather than whichever `pip` happens to be first on the shell PATH.\n",
    "%pip install --upgrade kagglehub -q\n",
    "%pip install ipywidgets -q\n",
    "%pip install torch --index-url https://download.pytorch.org/whl/cpu -q\n",
    "%pip install git+https://github.com/felafax/gemma.git -q\n",
    "%pip install qax -q\n",
    "%pip install jax-lorax -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the names returned by setup.setup_imports() into the notebook's global namespace.\n",
    "# NOTE(review): later cells use names that are never imported explicitly in this notebook\n",
    "# (e.g. re, jnp, torch, optax, chex, jax, Mesh, mesh_utils, load_dataset, AutoTokenizer);\n",
    "# presumably they are injected here -- confirm against trainer_engine.setup.\n",
    "globals().update(setup.setup_imports())\n",
    "\n",
    "# Load (or reload) the local trainer_engine helper modules.\n",
    "utils = import_local_module(\"trainer_engine.utils\")\n",
    "training_pipeline = import_local_module(\"trainer_engine.training_pipeline\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 0: Input your HF username, token and download model weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Select the base model you want to fine-tune 👇"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base models this notebook lists as supported.\n",
    "supported_models = [\n",
    "    \"gemma-2-2b-it\",  # 2b\n",
    "    \"gemma-2-9b-it\",  # 9b\n",
    "]\n",
    "\n",
    "# Pick one of the entries above.\n",
    "MODEL_NAME = \"gemma-2-9b-it\"\n",
    "\n",
    "# Fail fast on a typo instead of surfacing it later as a failed checkpoint download.\n",
    "if MODEL_NAME not in supported_models:\n",
    "    raise ValueError(\n",
    "        f\"Unsupported MODEL_NAME {MODEL_NAME!r}; choose one of {supported_models}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input your HuggingFace🤗 username and token below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INPUT: Please provide your HUGGINGFACE_USERNAME:  \n",
      "INPUT: Please provide your HUGGINGFACE_TOKEN:  \n"
     ]
    }
   ],
   "source": [
    "from getpass import getpass\n",
    "\n",
    "# Hugging Face repo id derived from the selected base model.\n",
    "hf_model_name = f\"felafax/{MODEL_NAME}-JAX\"\n",
    "HUGGINGFACE_USERNAME = input(\"INPUT: Please provide your HUGGINGFACE_USERNAME: \").strip()\n",
    "# Read the token with getpass so the secret is not echoed into the saved cell output;\n",
    "# strip() guards against whitespace picked up when pasting.\n",
    "HUGGINGFACE_TOKEN = getpass(\"INPUT: Please provide your HUGGINGFACE_TOKEN: \").strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Downloads the model to disk.\n",
    "import re  # imported here so this cell does not depend on names injected by setup_imports()\n",
    "\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "ckpt_path = snapshot_download(repo_id=hf_model_name, token=HUGGINGFACE_TOKEN)\n",
    "vocab_path = os.path.join(ckpt_path, 'tokenizer.model')\n",
    "# Derive the checkpoint sub-directory name by dropping the dash before the version\n",
    "# digit of the model name, e.g. \"gemma-2-9b-it\" -> \"gemma2-9b-it\".\n",
    "model_path = os.path.join(ckpt_path, re.sub(r'gemma-(\\d+)-', r'gemma\\1-', MODEL_NAME))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loads the downloaded checkpoint, then builds the model config, model and tokenizer.\n",
    "transformer_params = params_lib.load_and_format_params(model_path)['transformer']\n",
    "params = {\"params\": transformer_params}\n",
    "model_config = transformer_lib.TransformerConfig.from_params(\n",
    "    params={\"transformer\": transformer_params},\n",
    "    cache_size=30,\n",
    ")\n",
    "model = transformer_lib.Transformer(config=model_config)\n",
    "tokenizer = AutoTokenizer.from_pretrained(hf_model_name, token=HUGGINGFACE_TOKEN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: prepare the dataset\n",
    "\n",
    "For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.\n",
    "\n",
    "It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(*, tokenizer, batch_size=1, max_length=32, max_examples=None, seed=None):\n",
    "    \"\"\"Build train/test DataLoaders over the yahma/alpaca-cleaned dataset.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: Tokenizer used to encode the prompts; its ``eos_token`` is\n",
    "            appended to every example.\n",
    "        batch_size: Number of examples per batch in both loaders.\n",
    "        max_length: Length of each returned token sequence.\n",
    "        max_examples: Optional cap on the number of dataset rows used; clamped\n",
    "            to the dataset size. Falsy values (None, 0) use the full dataset.\n",
    "        seed: Optional seed for the train/test split. None (the default) leaves\n",
    "            the split unseeded.\n",
    "\n",
    "    Returns:\n",
    "        A ``(train_dataloader, test_dataloader)`` tuple. Batches are dicts with\n",
    "        ``input_tokens`` and ``target_mask`` entries.\n",
    "    \"\"\"\n",
    "    # Define Alpaca prompt template.\n",
    "    # NOTE(review): the continuation lines of this literal are indented, so every rendered\n",
    "    # prompt contains those leading spaces -- confirm this is intended before changing it.\n",
    "    alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
    "    \n",
    "    ### Instruction: {}\n",
    "    \n",
    "    ### Input: {}\n",
    "    \n",
    "    ### Response: {}\"\"\"\n",
    "\n",
    "    eos_token = tokenizer.eos_token\n",
    "\n",
    "    def _format_prompts(examples):\n",
    "        \"\"\"Render each (instruction, input, output) row into one prompt string ending in EOS.\"\"\"\n",
    "        rows = zip(examples[\"instruction\"], examples[\"input\"], examples[\"output\"])\n",
    "        # Named `context` rather than `input` to avoid shadowing the builtin.\n",
    "        texts = [\n",
    "            alpaca_prompt.format(instruction, context, output) + eos_token\n",
    "            for instruction, context, output in rows\n",
    "        ]\n",
    "        return {\"text\": texts}\n",
    "\n",
    "    def _tokenize(examples):\n",
    "        \"\"\"Tokenize to max_length + 1 positions, then drop the last one from ids and mask.\"\"\"\n",
    "        tokenized = tokenizer(examples[\"text\"], truncation=True, padding=\"max_length\", max_length=max_length+1)\n",
    "        return {\n",
    "            'input_tokens': [ids[:-1] for ids in tokenized['input_ids']],\n",
    "            'target_mask': [mask[:-1] for mask in tokenized['attention_mask']],\n",
    "        }\n",
    "\n",
    "    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:\n",
    "        \"\"\"\n",
    "        Collates batch items and converts PyTorch tensors to JAX arrays.\n",
    "        Applies default_data_collator, then converts tensors to JAX format;\n",
    "        non-tensor values are passed through unchanged.\n",
    "        \"\"\"\n",
    "        collated = default_data_collator(batch)\n",
    "        return {\n",
    "            key: jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value\n",
    "            for key, value in collated.items()\n",
    "        }\n",
    "\n",
    "    # Load and preprocess the dataset\n",
    "    dataset = load_dataset(\"yahma/alpaca-cleaned\", split=\"train\")\n",
    "    if max_examples:\n",
    "        # Clamp so a limit larger than the dataset does not make select() fail.\n",
    "        dataset = dataset.select(range(min(max_examples, len(dataset))))\n",
    "    dataset = dataset.map(_format_prompts, batched=True)\n",
    "\n",
    "    # Create train and test dataset.\n",
    "    ds = dataset.train_test_split(test_size=0.15, seed=seed)\n",
    "    for split in ['train', 'test']:\n",
    "        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)\n",
    "\n",
    "    # Create DataLoaders; only the training split is shuffled so evaluation order is stable.\n",
    "    dataloader_args = dict(batch_size=batch_size, collate_fn=_custom_collate_fn)\n",
    "    train_dataloader = torch.utils.data.DataLoader(ds['train'], shuffle=True, **dataloader_args)\n",
    "    test_dataloader = torch.utils.data.DataLoader(ds['test'], shuffle=False, **dataloader_args)\n",
    "\n",
    "    return train_dataloader, test_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Uncomment below code ⬇️ if you'd like to run and test 💯 your dataset pipeline.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def test_dataset_pipeline(tokenizer):\n",
    "#     \"\"\"Print shapes of first batch to verify dataset pipeline.\"\"\"\n",
    "#     train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, max_length=512)\n",
    "#     batch = next(iter(train_loader))\n",
    "#     print(\"Input tokens shape:\", batch['input_tokens'].shape)\n",
    "#     print(\"Target mask shape:\", batch['target_mask'].shape)\n",
    "# test_dataset_pipeline(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Train the model by configuring the hyperparameters below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "@chex.dataclass(frozen=True)\n",
    "class TrainingConfig:\n",
    "    \"\"\"Immutable bundle of the hyperparameters used by this notebook's training run.\"\"\"\n",
    "\n",
    "    # --- Optimisation ---\n",
    "    learning_rate: float = 1e-4\n",
    "    num_epochs: int = 1\n",
    "    # Max number of training steps (**set to None** to train for full num_epochs).\n",
    "    max_steps: int | None = 40\n",
    "\n",
    "    # --- Dataset config ---\n",
    "    batch_size: int = 32\n",
    "    # Max sequence length of tokens in an input batch.\n",
    "    max_length: int = 64\n",
    "    # Limit on number of dataset examples for testing (**set to None** to use full dataset).\n",
    "    dataset_size_limit: int | None = None\n",
    "\n",
    "    # --- Misc config ---\n",
    "    print_every_n_steps: int = 1\n",
    "\n",
    "\n",
    "training_cfg = TrainingConfig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1f5de4dbe9041be8ffeb058a2101894",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/43996 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1abc2e4a24b4c338b9f112df73ff650",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/7764 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Build the data loaders from the training config. batch_size must be passed explicitly:\n",
    "# get_dataset() defaults to batch_size=1, which would silently ignore training_cfg.batch_size.\n",
    "train_dataloader, val_dataloader = get_dataset(\n",
    "    tokenizer=tokenizer,\n",
    "    batch_size=training_cfg.batch_size,\n",
    "    max_length=training_cfg.max_length,\n",
    "    max_examples=training_cfg.dataset_size_limit,\n",
    ")\n",
    "# Plain SGD at the configured learning rate.\n",
    "optimizer = optax.sgd(training_cfg.learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sharding model acorss 4  devices.\n"
     ]
    }
   ],
   "source": [
    "# Sets up the device mesh for sharding the model across TPU cores and to do model parallel training.\n",
    "devices = jax.devices()\n",
    "device_count = len(devices)\n",
    "# All devices are placed on the 'model' axis; the 'data' and 'replica' axes have size 1.\n",
    "device_mesh = mesh_utils.create_device_mesh((1, device_count, 1))\n",
    "mesh = Mesh(devices=device_mesh, axis_names=('data', 'model', 'replica'))\n",
    "print(f\"Sharding model across {device_count} devices.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**NOTE**: The **time-to-first step of training will be slow** because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using the compiled and cached graph, leveraging the full power of all TPU cores for accelerated training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 0, Train Loss: 3.6503\n",
      "Step 1, Train Loss: 2.4262\n",
      "Step 2, Train Loss: 3.2592\n",
      "Step 3, Train Loss: 2.2910\n",
      "Step 4, Train Loss: 2.3543\n",
      "Step 5, Train Loss: 2.2851\n",
      "Step 6, Train Loss: 2.1437\n",
      "Step 7, Train Loss: 2.2376\n",
      "Step 8, Train Loss: 1.8792\n",
      "Step 9, Train Loss: 1.9982\n",
      "Step 10, Train Loss: 2.2012\n",
      "Step 11, Train Loss: 2.0664\n",
      "Step 12, Train Loss: 1.9715\n",
      "Step 13, Train Loss: 2.1281\n",
      "Step 14, Train Loss: 2.7430\n",
      "Step 15, Train Loss: 2.1233\n",
      "Step 16, Train Loss: 1.8954\n",
      "Step 17, Train Loss: 2.1002\n",
      "Step 18, Train Loss: 1.9138\n",
      "Step 19, Train Loss: 1.9351\n",
      "Step 20, Train Loss: 2.3138\n",
      "Step 21, Train Loss: 1.7246\n",
      "Step 22, Train Loss: 1.4195\n",
      "Step 23, Train Loss: 1.4904\n",
      "Step 24, Train Loss: 1.5521\n",
      "Step 25, Train Loss: 1.7244\n",
      "Step 26, Train Loss: 1.5133\n",
      "Step 27, Train Loss: 1.5379\n",
      "Step 28, Train Loss: 1.3721\n",
      "Step 29, Train Loss: 1.4148\n",
      "Step 30, Train Loss: 1.3466\n",
      "Step 31, Train Loss: 1.5437\n",
      "Step 32, Train Loss: 1.3475\n",
      "Step 33, Train Loss: 1.4616\n",
      "Step 34, Train Loss: 1.4873\n",
      "Step 35, Train Loss: 1.6208\n",
      "Step 36, Train Loss: 1.8785\n",
      "Step 37, Train Loss: 1.9606\n",
      "Step 38, Train Loss: 1.9625\n",
      "Step 39, Train Loss: 1.5795\n",
      "Step 40, Train Loss: 1.4650\n",
      "Training complete!\n"
     ]
    }
   ],
   "source": [
    "# Run the fine-tuning loop on the sharded mesh; the return value is bound to `state`.\n",
    "state = training_pipeline.train_loop(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    params=params,\n",
    "    optimizer=optimizer,\n",
    "    train_dataloader=train_dataloader,\n",
    "    training_cfg=training_cfg,\n",
    "    mesh=mesh,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "private_outputs": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
